import datetime
import os
import backtrader as bt
import pandas as pd

stk_num = 1             # 回测股票数目


# 创建策略
class BollStrategy(bt.Strategy):
    # 可配置策略参数
    params = dict(
        p_perion_volume=10,   # 前n日最大交易量
        p_sell_ma=5,        # 跌破该均线卖出
        p_oneplot=False,    # 是否打印到同一张图上
        pstake=1000,        # 单笔交易股票数

    )

    """这里定义了一个python字典类型变量self.inds，用于存储不同股票数据的技术指标，该字典的key为
    单支股票的数据，即代码中的d，value对应的该股票对应的技术指标，这些技术指标也存在一个字典内"""
    def __init__(self):
        self.inds = dict()      # 定义一个字典与实例对象进行关联的参数
        for i, d in enumerate(self.datas):
            self.inds[d] = dict()   # 相当于增加一列参数d与datas中数据进行对应
            boll_mid = bt.ind.BBands(d.colse).mid   # 布林中轨
            # 买入条件  突破中轨 ，放量
            self.inds[d]['buy_con'] = bt.And(d.open < boll_mid, d.close > boll_mid,
                                    d.volume == bt.ind.Highest(d.volume,
                                      period=self.p.p_period_volume, plot=False))
           # 卖出条件
            self.inds[d]['sell_con'] = d.close < bt.ind.SMA(d.close,
                                    period=self.p.p_sell_ma)

            # 如果多支股票回测，跳过第一支股票data,第一只股票data作为主图数据
            if i > 0:
                if self.p.p_oneplot:
                    d.plotinfo.plotmaster = self.datas[0]

    def __next__(self):
        global dt, dn
        for i, d in enumerate(self.datas):
            dt, dn = self.datetime.date(), d._name
            pos = self.getpositon(d).size
            if not pos:
                if self.inds[d]['buy_con']:
                    self.buy(data=d, size=self.p.pstake)
                elif self.inds[d]['sell_con']:
                    self.close(data=d)

    def notity_trade(self, trade):
        # dt = self.data.datetime.day()
        if trade.isclosed:
            print('\033[32m 日期：{}，股票代码：{},SELL EXECUTED,Price:{:.2f},'
                  'Gross:{:.2f},Net:{:.2f} \033[0m'
                  .format(dt, trade.data._name, trade.executed.price,
                          round(trade.pnl, 2), round(trade.pnlcomm, 2)))


# 创建cerebro
cerebro = bt.Cerebro()

# 读入股票代码
stk_code_file = '../data/code.cvs'
stk_pools = pd.read_csv(stk_code_file)

for i in range(stk_num):
    stk_code = stk_pools['code'][stk_pools.index[i]]
    stk_code = "%06d" % stk_code
    # 读入数据
    datapath = '../data/'+stk_code+'.csv'
    # 创建价格数据
    data = bt.feeds.GenericCSVData(
        dataname=datapath,
        fromdate=datetime.datetime(2018, 1, 1),
        todate=datetime.datetime(2021, 1, 1),
        nullvalue=0.0,
        dtformat="%Y-%m-%d",
        datetime=0,
        open=1,
        high=2,
        low=3,
        close=4,
        volume=5,
        openinterest=-1,
    )
    # 在cerebro中添加股票数据
    cerebro.adddata(data, name=stk_code)
# 设置启动资金
cerebro.broker.setcash(10000)
# 设置佣金为千分之一
cerebro.broker.setcommission(commission=0.001)
# 添加回测策略
cerebro.addstrategy(BollStrategy, p_oneplot=False)
# 遍历所有数据
cerebro.run()
# 打印最后结果
print("\033[31m Final Portfolio Value:%.2f  \033[0m"
      .format(cerebro.broker.getvalue()))
# 绘图
cerebro.plot()
